""" Match samples based on observed distribution of prizes

    The original maze game was written in Python 2.7, so
    pickled games also require Python 2.7 and a specific
    runtime environment.

    It is hard to setup the original python environment
    to de-pickle the experiments. We provide the code
    here for inspection, and provide the necessary data
    files to replicate the analysis in /data/ directory.
"""

# IMPORTS  ==========================================================================================================================

import os
import sys
import csv
import random
import pickle
from math import ceil
from collections import defaultdict
from collections import OrderedDict

# SETTINGS  ==========================================================================================================================

# Values of a game's `config` attribute that identify the two study arms;
# main() discards every game whose config is neither of these.
CONFIG_CONTROL                      = 10
CONFIG_TREATED                      = 12
# Directory scanned by load_games(); every regular file in it is unpickled as a game.
PATH_TO_GAMES                       = './data/raw_games/'
# Output CSV files written by main() via save_csv(). save_csv() does not create
# missing directories, so the target directory must already exist.
# Dataset and per-move match rate when matching on the signal only:
FNAME_MATCHED_SIGNALS               = 'processes_data/matched-signals.csv'
FNAME_MATCHED_SIGNALS_RATE          = 'processes_data/matched-signals-match-rate.csv'
# Dataset and per-move match rate when additionally matching within a window:
FNAME_MATCHED_SIGNALS_MOVES         = 'processes_data/matched-signals-moves.csv'
FNAME_MATCHED_SIGNALS_MOVES_RATE    = 'processes_data/matched-signals-moves-match-rate.csv'


# SUPPORTING FUNCTIONS  ===============================================================================================================

def progbar(x, label='', width=50):
    """Draw a one-line text progress bar on stdout.

    The line is rewritten in place (it starts with a carriage return) and a
    newline is printed once the bar reaches 100%.

    Args:
        x: Ratio completed, normally between 0.0 and 1.0.
        label: Optional text shown in front of the percentage.
        width: Width setting; the bar itself is ``width - 9`` characters
            wide, capped at 100.

    Returns:
        None. Errors are printed instead of raised so that a display problem
        never aborts the computation being tracked.
    """
    try:
        prefix = (str(label).strip() + ': ') if label else ''
        percent = x * float(100)
        bar_width = min(100, width - 9)
        filled = int(ceil(float(percent * bar_width) / float(100)))
        # Clamp to the bar itself (not to the literal 100) so ratios outside
        # [0, 1] can neither overflow nor stretch the bar.
        filled = max(0, min(bar_width, filled))
        pct_text = str("%.1f" % round(percent, 1)).rjust(4, ' ') + '%'
        sys.stdout.write('\r' + prefix + pct_text
                         + ' [' + '#' * filled + '-' * (bar_width - filled) + ']')
        sys.stdout.flush()
        if percent >= 100.0:
            print()
    except Exception as e:
        # Deliberate best-effort: report the problem and carry on.
        print(e)

def save_csv(lst, fname):
    """Write rows to a CSV file in the 'excel' dialect, replacing any old file.

    Args:
        lst: Iterable of rows; each row is an iterable of cell values.
        fname: Destination path. Surrounding whitespace is stripped. The
            containing directory must already exist.
    """
    fname = str(fname.strip())
    if os.path.exists(fname): os.remove(fname)
    # The csv module emits its own '\r\n' line terminator. The file must not
    # translate newlines again, otherwise Windows output gets a blank line
    # after every row: Python 3 needs newline='', Python 2 needs binary mode.
    if sys.version_info[0] >= 3:
        f = open(fname, 'w', newline='')
    else:
        f = open(fname, 'wb')
    with f:
        writer = csv.writer(f, dialect='excel')
        writer.writerows(lst)

def load_games(path_to_games):
    """Unpickle every file in a directory and return the completed games.

    Progress and the total/completed counts are printed to stdout.

    Args:
        path_to_games: Directory containing one pickled game object per file.
            A trailing path separator is accepted but not required.

    Returns:
        List of the unpickled objects whose ``completed`` attribute is truthy,
        in file-name order.

    Raises:
        AssertionError: If the directory contains no files to load.
    """
    print()
    # Sort the names: os.listdir() returns them in arbitrary order, and a
    # stable load order keeps downstream processing reproducible.
    fnames = sorted(f for f in os.listdir(path_to_games)
                    if os.path.isfile(os.path.join(path_to_games, f)))
    games = list()
    for fname in fnames:
        # SECURITY: unpickling executes code embedded in the file; only point
        # this at trusted game files.
        # Pickles are binary data, so read them in 'rb' mode (text mode fails
        # on Python 3 and can corrupt the stream on Windows).
        with open(os.path.join(path_to_games, fname), 'rb') as f:
            games.append(pickle.load(f))
        progbar(float(len(games)) / float(len(fnames)), label='Loading games')
    if not games:
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O` while callers still see the same exception type.
        raise AssertionError("No games")
    print('  Total games:  '.ljust(20) + str(len(games)))

    # drop games that were not completed
    games = [g for g in games if g.completed]
    print('  Completed games: '.ljust(20) + str(len(games)))

    print()
    return games


# MAIN FUNCTION  ==========================================================================================================================================

def main():
    """Build the matched treated/control datasets and write them to CSV.

    Steps:
      1. Load the games and keep those whose config is CONFIG_CONTROL or
         CONFIG_TREATED.
      2. Replay each game's history to record, per game and per signal (the
         running count of prizes seen at each prize position), the decisions
         taken and the number of moves/actions into the game.
      3. For every signal a treated game decided on, pick at random one
         control game that decided on the same signal; optionally the control
         must also be within a window of actions of the treated game.
      4. Write the matched dataset and the per-move match rate, once for
         signal-only matching and once for signal-plus-window matching.
    """

    print()
    print('========================================================================')
    print('      Match observations cycle-by-cycle for treated vs. control         ')
    print('========================================================================')
    print()

    # Init Vars
    random.seed(8768687)  # fixed seed so the random control picks can be replicated
    # NOTE(review): the picks also depend on the order of the EXPOSURE lists,
    # i.e. on the iteration order of GAME_OBJECTS below - confirm this is
    # stable in the environment used for replication.
    EXPOSURE = defaultdict(list)  # [signal]      = list of eids having been exposed to that signal
    DECISIONS = defaultdict(OrderedDict)  # [eid][signal] = list of decisions taken at that signal
    ACTIONS = defaultdict(dict)  # [eid][signal] = num ACTIONS into game at the latest decision for that signal
    MOVES = defaultdict(dict)  # [eid][signal] = num MOVES into game at the latest decision for that signal

    # Load Games
    GAME_OBJECTS = load_games(PATH_TO_GAMES)
    GAME_OBJECTS = {g.eid: g for g in GAME_OBJECTS
                    if g.config in (CONFIG_CONTROL, CONFIG_TREATED)}
    # Short integer id per game (rank of its eid) for use as a cluster id.
    IDS = {eid: i for i, eid in enumerate(sorted(GAME_OBJECTS))}

    # Preprocess Game Data to Use for Analysis =========================================
    for count, (eid, g) in enumerate(GAME_OBJECTS.items()):

        progbar(float(count) / float(len(GAME_OBJECTS)), label='Processing games')

        top_open = False
        signal = [0, 0, 0, 0]
        moves_taken = 0
        actions_taken = 0
        EXPOSURE[tuple(signal)].append(eid)

        for (x0, x1, x2, x3, x4, x5, x6, row, col, prize_row, prize_col,
             at_prize, at_start, at_door) in g.lst_history:

            actions_taken += 1

            # Every record away from cell (3, 3) counts as a move.
            if row != 3 or col != 3:
                moves_taken += 1

            # Passed through top door
            if row == 6:
                top_open = True

            # Observed a signal
            if row == 3 and col == 3 and top_open:

                # update signal: count the prize at its position index
                for j, position in enumerate(g.prize_positions):
                    if (prize_row, prize_col) == position:
                        signal[j] += 1

                # track EXPOSURE
                EXPOSURE[tuple(signal)].append(eid)

            # Track decision
            if top_open and col == 3 and (row == 2 or row == 4):

                # row 4 -> 0 (move UP), row 2 -> 1 (move DOWN)
                decision = 0 if row == 4 else 1
                key = tuple(signal)

                # track decision for that signal
                DECISIONS[eid].setdefault(key, list()).append(decision)

                # A later decision at the same signal overwrites these, so
                # they hold the counts at the most recent decision.
                MOVES[eid][key] = moves_taken
                ACTIONS[eid][key] = actions_taken

    print()
    print('\t\t# Subjects: ' + str(len(DECISIONS.keys())))
    print('\t\t# Signals:  ' + str(len(EXPOSURE.keys())))
    print()

    # Matching Function used within this outer function
    def match_sample(also_match_by_window=False, window_size=5):
        """Match each treated (game, signal) to one random control game.

        Args:
            also_match_by_window: When true, a control only qualifies if its
                action count at the signal is within ``window_size`` of the
                treated game's.
            window_size: Maximum allowed absolute difference in action counts.

        Returns:
            (dataset, match_rate). ``dataset`` holds two rows per treated
            signal: (matched, cycle, move_num, eid_tx, eid_ctrl, treated, dv).
            ``match_rate`` is a list of (move_num, share matched) pairs for
            move numbers 1..500 that occurred.
        """

        dataset = list()
        hits = defaultdict(int)
        miss = defaultdict(int)

        eids_tx = {e for e, g in GAME_OBJECTS.items() if g.config == CONFIG_TREATED}
        eids_ctrl = {e for e, g in GAME_OBJECTS.items() if g.config == CONFIG_CONTROL}

        for eid_tx in sorted(eids_tx):

            print('\n--------------------------------------------------------')
            print('\nGAME ' + str(eid_tx))
            print('\n--------------------------------------------------------')

            # cycle = 1-based position of the signal in the treated game's
            # decision order
            for cycle, signal in enumerate(DECISIONS[eid_tx].keys(), 1):

                print('\nCYCLE: ' + str(cycle))

                # note that in some cases the game will end right after being
                # exposed to a signal but before a decision is made;
                # therefore, check that signal in DECISIONS[e]
                matched_controls = [e for e in EXPOSURE[signal]
                                    if e not in eids_tx
                                    and e in eids_ctrl
                                    and signal in DECISIONS[e]]

                if also_match_by_window:
                    # NOTE(review): the window compares ACTIONS although the
                    # outputs are labelled "moves" - confirm this is intended.
                    matched_controls = [
                        e for e in matched_controls
                        if abs(ACTIONS[eid_tx][signal] - ACTIONS[e][signal]) <= window_size
                    ]

                print('\tMatches:'.ljust(25) + str(len(matched_controls)))

                # decision value (1 if the game *EVER* explored at this signal)
                dv_tx = max(DECISIONS[eid_tx][signal])
                # number of MOVES into the game at this point
                move_num_tx = MOVES[eid_tx][signal]

                if matched_controls:

                    # select a random match
                    eid_ctrl = random.choice(matched_controls)

                    dv_ctrl = max(DECISIONS[eid_ctrl][signal])
                    move_num_ctrl = MOVES[eid_ctrl][signal]

                    # add to dataset (matched, cycle, move_num, eid_tx, eid_ctrl, treated, dv)
                    dataset.append((1, cycle, move_num_tx, eid_tx, eid_ctrl, 1, dv_tx))
                    dataset.append((1, cycle, move_num_ctrl, eid_tx, eid_ctrl, 0, dv_ctrl))

                    hits[move_num_tx] += 1

                    print('\tSelected:' + str(eid_ctrl))
                    print('\tTx:      ' + str(DECISIONS[eid_tx][signal]))
                    print('\tCtrl:    ' + str(DECISIONS[eid_ctrl][signal]))
                    print('\tDV Tx:   ' + str(dv_tx))
                    print('\tDV Ctrl: ' + str(dv_ctrl))

                else:
                    miss[move_num_tx] += 1
                    # No control exists here, so there is no control decision:
                    # record the treated dv for this signal and leave the
                    # control dv empty instead of reusing values left over
                    # from an earlier match.
                    dataset.append((0, cycle, '', eid_tx, '', 1, dv_tx))
                    dataset.append((0, cycle, '', eid_tx, '', 0, ''))
                    print('\t~~~~~~~~~~ NONE ~~~~~~~~~~')

        # Share of treated signals matched, per move number. Only move numbers
        # 1..500 are reported; any others are left out of the rate table.
        match_rate = []
        for m in range(1, 501):
            total = hits[m] + miss[m]
            if total:
                match_rate.append((m, float(hits[m]) / total))

        return dataset, match_rate

    def save_matched(dataset, match_rate, fname_dataset, fname_rate):
        """Write one matched dataset and its match-rate table to CSV."""
        rows = [('matched', 'signal', 'moves', 'eid', 'eid_tx', 'eid_ctrl', 'treated', 'dv'), ]
        for matched, cycle, move_num, eid_tx, eid_ctrl, treated, dv in dataset:
            # The long eid is converted to a simple ID so the stats package
            # can cluster on it. Unmatched rows have no control, hence no id.
            simple_id_ctrl = IDS[eid_ctrl] if matched else ''
            rows.append((matched, cycle, move_num, eid_tx, IDS[eid_tx],
                         simple_id_ctrl, treated, dv))
        save_csv(rows, fname_dataset)
        save_csv([('moves', 'rate'), ] + match_rate, fname_rate)

    # Match observations by signal only ===============================================
    dataset, match_rate = match_sample()
    save_matched(dataset, match_rate, FNAME_MATCHED_SIGNALS, FNAME_MATCHED_SIGNALS_RATE)

    print('SAVED matched dataset by signals')

    # Match observations by signals and MOVES  ===========================================
    dataset, match_rate = match_sample(also_match_by_window=True, window_size=1)
    save_matched(dataset, match_rate, FNAME_MATCHED_SIGNALS_MOVES, FNAME_MATCHED_SIGNALS_MOVES_RATE)

    print('SAVED matched dataset by signals and MOVES')


# COMMANDLINE ====================================================================================================================
if __name__ == "__main__":

    # This file requires the original Python 2 environment to run, so when it
    # is executed as a script it only prints a notice and the module docstring.
    print('###########################################################')
    print('     THIS CODE REQUIRES A SPECIFIC ENVIRONMENT TO RUN      ')
    print('  WE PROVIDE THE RAW DATA FILES FOR THE MATCHED SAMPLES    ')
    print('###########################################################')
    print(__doc__)
    print('###########################################################')

    # The call to main() is intentionally disabled.
    # Uncomment the following line to run this package in the original environment
    #main()

